#[burn_tensor_testgen::testgen(ad_deform_conv2d)]
mod tests {
    use super::*;
    use burn_tensor::{ElementConversion, Tolerance};
    use burn_tensor::{Shape, module::deform_conv2d, ops::DeformConvOptions};

    /// Checks the gradients of `deform_conv2d` in the smallest configuration of this suite.
    ///
    /// Configuration:
    /// - one 2-channel 4x4 input and three output channels;
    /// - a 3x3 kernel with padding 0, stride 1 and dilation 1;
    /// - a single group and a single offset group.
    ///
    /// `test.assert_grads` is defined elsewhere in this file. It is expected to compare the
    /// gradients computed through autodiff against the hard-coded values below.
    /// NOTE(review): the origin of these reference values (e.g. an external reference
    /// implementation) is not recorded here — TODO confirm and document.
    #[test]
    fn test_deform_conv2d_basic() {
        let test = Conv2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 3,
            kernel_size_1: 3,
            kernel_size_2: 3,
            padding_1: 0,
            padding_2: 0,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            offset_groups: 1,
            height: 4,
            width: 4,
        };
        let device = Default::default();
        // Expected gradient for each input of the convolution.
        let grads = Grads {
            // Input gradient: literal shape [1, 2, 4, 4], matching
            // batch_size x channels_in x height x width above.
            x: TestTensor::from_floats(
                [[
                    [
                        [0.000, 6.0678, 14.2071, 12.2477],
                        [11.2292, 33.7937, 50.1555, 44.0561],
                        [17.9294, 57.2174, 85.1505, 79.1840],
                        [18.0220, 73.6263, 126.8184, 151.6910],
                    ],
                    [
                        [0.000, 8.9783, 20.7620, 17.7888],
                        [16.2326, 48.7386, 71.7961, 62.5845],
                        [25.3808, 80.5195, 119.0949, 110.0938],
                        [25.0567, 101.8461, 174.3329, 206.6013],
                    ],
                ]],
                &device,
            ),
            // Offset gradient: literal shape [1, 18, 2, 2].
            // - 18 = 2 * 3 * 3 channels, presumably a 2-D offset pair per kernel tap for the
            //   single offset group.
            // - 2x2 is the output size for a 4x4 input with a 3x3 kernel and no padding.
            offset: TestTensor::from_floats(
                [[
                    [[0.000, 15.0000], [30.000, 45.0000]],
                    [[0.000, 3.7500], [7.5000, 11.2500]],
                    [[62.6667, 78.3333], [94.0000, 109.6667]],
                    [[15.6667, 19.5833], [23.5000, 27.4167]],
                    [[130.6667, 104.1250], [163.3333, 122.2732]],
                    [[32.6667, -492.9583], [40.8333, -787.1620]],
                    [[204.0000, 221.0000], [238.0000, 255.0000]],
                    [[51.0000, 55.2500], [59.5000, 63.7500]],
                    [[282.6667, 300.3333], [318.0000, 335.6667]],
                    [[70.6667, 75.0833], [79.5000, 83.9167]],
                    [[366.6667, 144.3750], [403.3333, 146.4121]],
                    [[91.6667, -1788.9860], [100.8333, -2392.7456]],
                    [[456.0000, 475.0000], [-2718.6250, -2953.2188]],
                    [[114.0000, 118.7500], [37.7361, 37.4063]],
                    [[550.6667, 570.3334], [-3404.5139, -3672.5312]],
                    [[137.6667, 142.5833], [28.6806, 27.5197]],
                    [[650.6667, 27.9584], [-4174.3657, -59.7509]],
                    [[162.6667, -3991.0139], [14.4028, -298.7557]],
                ]],
                &device,
            ),
            // Weight gradient: literal shape [3, 2, 3, 3]
            // (channels_out, channels_in, kernel_size_1, kernel_size_2).
            // The three output-channel slices are identical. That is consistent with an
            // upstream gradient that does not vary across output channels (assumption —
            // see `assert_grads`).
            weight: TestTensor::from_floats(
                [
                    [
                        [
                            [0.7029, 2.8356, 5.1067],
                            [12.7492, 19.4745, 17.8345],
                            [22.0687, 25.9156, 14.6394],
                        ],
                        [
                            [3.3696, 12.6134, 19.2671],
                            [36.7492, 50.5856, 43.5506],
                            [50.8774, 56.3292, 30.7470],
                        ],
                    ],
                    [
                        [
                            [0.7029, 2.8356, 5.1067],
                            [12.7492, 19.4745, 17.8345],
                            [22.0687, 25.9156, 14.6394],
                        ],
                        [
                            [3.3696, 12.6134, 19.2671],
                            [36.7492, 50.5856, 43.5506],
                            [50.8774, 56.3292, 30.7470],
                        ],
                    ],
                    [
                        [
                            [0.7029, 2.8356, 5.1067],
                            [12.7492, 19.4745, 17.8345],
                            [22.0687, 25.9156, 14.6394],
                        ],
                        [
                            [3.3696, 12.6134, 19.2671],
                            [36.7492, 50.5856, 43.5506],
                            [50.8774, 56.3292, 30.7470],
                        ],
                    ],
                ],
                &device,
            ),
            // Mask gradient: literal shape [1, 9, 2, 2]. 9 = 3 * 3, presumably one
            // modulation scalar per kernel tap at each output position.
            mask: TestTensor::from_floats(
                [[
                    [[1303.5000, 1447.8750], [1862.2500, 2006.6250]],
                    [[1571.1666, 1721.9581], [2154.7500, 2305.5417]],
                    [[1857.4999, 1396.7151], [2465.9167, 1753.2246]],
                    [[2315.5000, 2479.1250], [2948.7502, 3112.3750]],
                    [[2645.1665, 2815.2085], [3303.2500, 3473.2917]],
                    [[2993.5000, 1150.0625], [3676.4165, 1300.4055]],
                    [[3531.5000, 3714.3752], [1150.1876, 1148.4744]],
                    [[3923.1665, 4112.4585], [794.3865, 770.0470]],
                    [[4333.5000, 181.4101], [368.3260, 4.2679]],
                ]],
                &device,
            ),
            // Bias gradient: one entry per output channel. Each entry is 4, which equals
            // batch_size (1) * 2 * 2 output positions. That is consistent with a
            // ones-seeded backward pass (assumption — see `assert_grads`).
            bias: TestTensor::from_floats([4., 4., 4.], &device),
        };
        test.assert_grads(grads);
    }

    /// Checks the gradients of `deform_conv2d` on the same configuration as the basic test,
    /// but with a batch of two inputs.
    ///
    /// Configuration: two 2-channel 4x4 inputs, three output channels, a 3x3 kernel,
    /// padding 0, stride 1, dilation 1, and one group and one offset group.
    ///
    /// The comparison is delegated to `test.assert_grads`, which is defined elsewhere in
    /// this file.
    #[test]
    fn test_deform_conv2d_batched() {
        let test = Conv2dTestCase {
            batch_size: 2,
            channels_in: 2,
            channels_out: 3,
            kernel_size_1: 3,
            kernel_size_2: 3,
            padding_1: 0,
            padding_2: 0,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            offset_groups: 1,
            height: 4,
            width: 4,
        };
        let device = Default::default();
        // Expected gradient for each input of the convolution.
        let grads = Grads {
            // Input gradient: literal shape [2, 2, 4, 4], one [2, 4, 4] slice per batch item.
            x: TestTensor::from_floats(
                [
                    [
                        [
                            [0.000, 3.4604, 8.7539, 6.8080],
                            [8.4661, 24.0784, 35.4610, 26.4276],
                            [19.5988, 51.0406, 68.4389, 53.4993],
                            [17.4698, 47.9106, 67.3808, 56.6063],
                        ],
                        [
                            [0.000, 5.1185, 12.7803, 9.8796],
                            [12.1957, 34.5728, 50.4616, 37.3777],
                            [27.4521, 71.1227, 94.5778, 73.4724],
                            [24.1147, 65.8443, 91.8995, 76.7475],
                        ],
                    ],
                    [
                        [
                            [6.3750, 19.3553, 26.4935, 22.5650],
                            [17.0026, 57.8088, 85.5580, 78.0746],
                            [20.7334, 86.5793, 139.4667, 136.4133],
                            [16.8126, 103.0225, 186.4502, 206.9613],
                        ],
                        [
                            [9.5625, 28.8786, 39.1137, 32.9178],
                            [25.1984, 85.0747, 124.6941, 112.5691],
                            [30.0242, 124.2863, 198.6056, 192.4489],
                            [23.5826, 143.4660, 257.8752, 283.2587],
                        ],
                    ],
                ],
                &device,
            ),

            // Offset gradient: literal shape [2, 18, 2, 2].
            // 18 = 2 * 3 * 3, presumably a 2-D offset pair per kernel tap.
            offset: TestTensor::from_floats(
                [
                    [
                        [[0.000, 7.5000], [15.0000, 22.5000]],
                        [[0.000, 1.8750], [3.7500, 5.6250]],
                        [[31.3333, 39.1667], [47.0000, 54.8333]],
                        [[7.8333, 9.7917], [11.7500, 13.7083]],
                        [[65.3333, 62.7813], [81.6667, 75.4849]],
                        [[16.3333, -237.8021], [20.4167, -381.7280]],
                        [[102.0000, 110.5000], [119.0000, 127.5000]],
                        [[25.5000, 27.6250], [29.7500, 31.8750]],
                        [[141.3333, 150.1667], [159.0000, 167.8333]],
                        [[35.3333, 37.5417], [39.7500, 41.9583]],
                        [[183.3333, 132.3438], [201.6667, 142.0197]],
                        [[45.8333, -839.6840], [50.4167, -1133.4155]],
                        [[228.0000, 237.5000], [-1336.1562, -1452.1173]],
                        [[57.0000, 59.3750], [40.3090, 41.4141]],
                        [[275.3333, 285.1667], [-1670.5034, -1802.9244]],
                        [[68.8333, 71.2917], [44.0451, 44.9841]],
                        [[325.3333, 174.7396], [-2045.1747, -1090.4585]],
                        [[81.3333, -1844.0659], [46.8090, -1150.2101]],
                    ],
                    [
                        [[270.000, 277.5000], [285.0000, 292.5000]],
                        [[67.5000, 69.3750], [71.2500, 73.1250]],
                        [[313.3333, 321.1667], [329.0000, 336.8333]],
                        [[78.3333, 80.2917], [82.2500, 84.2083]],
                        [[359.3333, 130.1563], [375.6667, 130.6099]],
                        [[89.8333, -4312.7603], [93.9167, -4893.6035]],
                        [[408.0000, 416.5000], [425.0000, 433.5000]],
                        [[102.0000, 104.1250], [106.2500, 108.3750]],
                        [[459.3333, 468.1667], [477.0000, 485.8333]],
                        [[114.8333, 117.0417], [119.2500, 121.4583]],
                        [[513.3334, 97.9688], [531.6667, 93.8947]],
                        [[128.3333, -6720.3926], [132.9167, -7504.5405]],
                        [[570.000, 579.5000], [-7971.8438, -8251.0850]],
                        [[142.5000, 144.8750], [22.4965, 21.8203]],
                        [[629.3333, 639.1667], [-8948.2334, -9249.6641]],
                        [[157.3333, 159.7917], [15.7743, 14.8695]],
                        [[691.3333, 14.6145], [-9992.9453, -70.4040]],
                        [[172.8333, -9818.5234], [7.4132, -352.0222]],
                    ],
                ],
                &device,
            ),
            // Weight gradient: literal shape [3, 2, 3, 3]. The three output-channel slices
            // are identical, and this gradient accumulates over both batch items.
            weight: TestTensor::from_floats(
                [
                    [
                        [
                            [77.7195, 89.8692, 69.0213],
                            [121.0760, 137.0775, 92.2989],
                            [100.0212, 106.5561, 61.1851],
                        ],
                        [
                            [112.3862, 131.6470, 103.8793],
                            [177.0760, 200.1887, 138.2681],
                            [149.5922, 158.7074, 94.3991],
                        ],
                    ],
                    [
                        [
                            [77.7195, 89.8692, 69.0213],
                            [121.0760, 137.0775, 92.2989],
                            [100.0212, 106.5561, 61.1851],
                        ],
                        [
                            [112.3862, 131.6470, 103.8793],
                            [177.0760, 200.1887, 138.2681],
                            [149.5922, 158.7074, 94.3991],
                        ],
                    ],
                    [
                        [
                            [77.7195, 89.8692, 69.0213],
                            [121.0760, 137.0775, 92.2989],
                            [100.0212, 106.5561, 61.1851],
                        ],
                        [
                            [112.3862, 131.6470, 103.8793],
                            [177.0760, 200.1887, 138.2681],
                            [149.5922, 158.7074, 94.3991],
                        ],
                    ],
                ],
                &device,
            ),
            // Mask gradient: literal shape [2, 9, 2, 2]. 9 = 3 * 3, presumably one scalar
            // per kernel tap.
            mask: TestTensor::from_floats(
                [
                    [
                        [[1299.7499, 1439.4375], [1849.1249, 1988.8125]],
                        [[1528.0834, 1673.9791], [2101.8750, 2247.7708]],
                        [[1771.7500, 1624.9811], [2369.9583, 2099.5039]],
                        [[2183.7500, 2342.0625], [2806.3750, 2964.6875]],
                        [[2464.0833, 2628.6042], [3111.1250, 3275.6458]],
                        [[2759.7500, 1979.2551], [3431.2085, 2390.0286]],
                        [[3241.7498, 3418.6873], [2415.3589, 2500.8682]],
                        [[3574.0835, 3757.2292], [2394.3889, 2471.7510]],
                        [[3921.7500, 2095.5293], [2345.9363, 1199.5048]],
                    ],
                    [
                        [[5957.2500, 6096.9375], [6506.6250, 6646.3125]],
                        [[6392.5835, 6538.4790], [6966.3750, 7112.2705]],
                        [[6843.2500, 2443.8982], [7441.4585, 2550.9199]],
                        [[7462.2505, 7620.5625], [8084.8745, 8243.1875]],
                        [[7949.5835, 8114.1045], [8596.6250, 8761.1465]],
                        [[8452.2500, 1591.6719], [9123.7080, 1589.9454]],
                        [[9141.2500, 9318.1875], [1414.3584, 1375.1803]],
                        [[9680.5840, 9863.7285], [949.0560, 897.3544]],
                        [[10235.2500, 213.4454], [428.2699, 2.4790]],
                    ],
                ],
                &device,
            ),
            // Bias gradient: each entry is 8, which equals batch_size (2) * 2 * 2 output
            // positions. That is consistent with a ones-seeded backward pass (assumption).
            bias: TestTensor::from_floats([8., 8., 8.], &device),
        };
        test.assert_grads(grads);
    }

    /// Checks the gradients of `deform_conv2d` with a non-square 3x4 kernel and padding 1
    /// on both spatial axes.
    ///
    /// With a 4x4 input, this configuration gives a 4x3 output:
    /// - height: 4 + 2 - 3 + 1 = 4;
    /// - width: 4 + 2 - 4 + 1 = 3.
    ///
    /// The 4x3 shape matches the trailing dimensions of the offset and mask literals below.
    /// The comparison is delegated to `test.assert_grads`, which is defined elsewhere in
    /// this file.
    #[test]
    fn test_deform_conv2d_different_kernel_size() {
        let test = Conv2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size_1: 3,
            kernel_size_2: 4,
            padding_1: 1,
            padding_2: 1,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            offset_groups: 1,
            height: 4,
            width: 4,
        };
        let device = Default::default();
        // Expected gradient for each input of the convolution.
        let grads = Grads {
            // Input gradient: literal shape [1, 2, 4, 4].
            x: TestTensor::from_floats(
                [[
                    [
                        [14.558521, 27.249609, 37.382030, 36.039406],
                        [33.151936, 60.480656, 81.264656, 78.618156],
                        [57.520061, 108.623283, 153.413559, 170.072998],
                        [54.706184, 102.596664, 144.367157, 162.643570],
                    ],
                    [
                        [25.836353, 48.088451, 65.249161, 62.103317],
                        [56.805233, 102.995605, 136.983124, 131.120911],
                        [96.105408, 179.790192, 250.550934, 272.668793],
                        [90.210945, 167.567917, 232.847275, 257.934692],
                    ],
                ]],
                &device,
            ),
            // Offset gradient: literal shape [1, 24, 4, 3].
            // - 24 = 2 * 3 * 4, presumably a 2-D offset pair per kernel tap.
            // - Several channels are exactly 0 along the last output row or column,
            //   presumably where those taps sample only outside the padded input (unverified).
            offset: TestTensor::from_floats(
                [[
                    [
                        [0.0e+00, 5.355903e+00, 1.171528e+01],
                        [3.124999e-01, 8.000000e+00, 1.000000e+01],
                        [7.500000e-01, 1.400000e+01, 1.600000e+01],
                        [1.312500e+00, 2.000000e+01, 2.200000e+01],
                    ],
                    [
                        [0.0e+00, 1.736104e-03, 6.944418e-03],
                        [1.606250e+01, 2.000000e+00, 2.500000e+00],
                        [4.425000e+01, 3.500000e+00, 4.000000e+00],
                        [8.456250e+01, 5.000000e+00, 5.500000e+00],
                    ],
                    [
                        [6.745834e+01, 7.996479e+01, 9.353048e+01],
                        [3.166667e+01, 3.377778e+01, 3.588889e+01],
                        [3.800000e+01, 4.011111e+01, 4.222223e+01],
                        [4.433333e+01, 4.644444e+01, 4.855556e+01],
                    ],
                    [
                        [5.277777e-01, 5.955827e-01, 6.670526e-01],
                        [7.916667e+00, 8.444445e+00, 8.972222e+00],
                        [9.500000e+00, 1.002778e+01, 1.055556e+01],
                        [1.108333e+01, 1.161111e+01, 1.213889e+01],
                    ],
                    [
                        [1.547778e+02, 1.751640e+02, 1.518874e+02],
                        [6.000000e+01, 6.222223e+01, 4.989969e+01],
                        [6.666666e+01, 6.888889e+01, 5.432098e+01],
                        [7.333334e+01, 7.555556e+01, 5.860340e+01],
                    ],
                    [
                        [2.222223e+00, 2.363040e+00, -3.360339e+01],
                        [1.500000e+01, 1.555556e+01, -2.277485e+02],
                        [1.666667e+01, 1.722222e+01, -3.231605e+02],
                        [1.833333e+01, 1.888889e+01, -4.320448e+02],
                    ],
                    [
                        [2.641250e+02, 2.021189e+02, 0.0e+00],
                        [9.100000e+01, 6.481482e+01, 0.0e+00],
                        [9.800000e+01, 6.863078e+01, 0.0e+00],
                        [1.050000e+02, 7.230093e+01, 0.0e+00],
                    ],
                    [
                        [5.250000e+00, -7.268316e+01, 0.0e+00],
                        [2.275000e+01, -3.346296e+02, 0.0e+00],
                        [2.450000e+01, -4.611053e+02, 0.0e+00],
                        [2.625000e+01, -6.017269e+02, 0.0e+00],
                    ],
                    [
                        [4.400000e+01, 1.197778e+02, 1.222222e+02],
                        [4.804860e+01, 1.271111e+02, 1.295556e+02],
                        [5.225000e+01, 1.344444e+02, 1.368889e+02],
                        [-3.138958e+02, -8.007446e+02, -8.507313e+02],
                    ],
                    [
                        [3.377778e+02, 2.994445e+01, 3.055556e+01],
                        [4.848542e+02, 3.177778e+01, 3.238889e+01],
                        [6.467500e+02, 3.361111e+01, 3.422222e+01],
                        [4.909653e+02, 2.239892e+01, 2.265992e+01],
                    ],
                    [
                        [1.533333e+02, 1.558889e+02, 1.584444e+02],
                        [1.610000e+02, 1.635556e+02, 1.661111e+02],
                        [1.686667e+02, 1.712222e+02, 1.737778e+02],
                        [-9.952491e+02, -1.054551e+03, -1.115134e+03],
                    ],
                    [
                        [3.833333e+01, 3.897222e+01, 3.961111e+01],
                        [4.025000e+01, 4.088889e+01, 4.152778e+01],
                        [4.216667e+01, 4.280556e+01, 4.344445e+01],
                        [2.433767e+01, 2.453511e+01, 2.472810e+01],
                    ],
                    [
                        [1.920000e+02, 1.946667e+02, 8.907407e+01],
                        [2.000000e+02, 2.026667e+02, 9.054632e+01],
                        [2.080000e+02, 2.106667e+02, 9.185186e+01],
                        [-1.272938e+03, -1.343509e+03, -5.811921e+02],
                    ],
                    [
                        [4.800000e+01, 4.866667e+01, -7.413704e+02],
                        [5.000000e+01, 5.066667e+01, -9.788981e+02],
                        [5.200000e+01, 5.266667e+01, -1.232593e+03],
                        [2.531250e+01, 2.543518e+01, -6.388311e+02],
                    ],
                    [
                        [2.333333e+02, 8.772182e+01, 0.0e+00],
                        [2.416667e+02, 8.827161e+01, 0.0e+00],
                        [2.500000e+02, 8.864776e+01, 0.0e+00],
                        [-1.587216e+03, -5.535372e+02, 0.0e+00],
                    ],
                    [
                        [5.833333e+01, -9.011902e+02, 0.0e+00],
                        [6.041667e+01, -1.179988e+03, 0.0e+00],
                        [6.250000e+01, -1.475625e+03, 0.0e+00],
                        [2.489150e+01, -6.213175e+02, 0.0e+00],
                    ],
                    [
                        [1.964444e+02, 2.802222e+02, 2.831111e+02],
                        [2.055625e+02, 2.888889e+02, 2.917778e+02],
                        [-1.173472e+03, -1.679611e+03, -1.771290e+03],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [1.144889e+03, 7.005556e+01, 7.077778e+01],
                        [1.469646e+03, 7.222223e+01, 7.294444e+01],
                        [5.029167e+02, 2.298823e+01, 2.295062e+01],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [3.240000e+02, 3.270000e+02, 3.300000e+02],
                        [3.330000e+02, 3.360000e+02, 3.390000e+02],
                        [-1.931469e+03, -2.034961e+03, -2.139958e+03],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [8.100000e+01, 8.175000e+01, 8.250000e+01],
                        [8.325000e+01, 8.400000e+01, 8.475000e+01],
                        [1.959376e+01, 1.946614e+01, 1.933334e+01],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [3.733333e+02, 3.764445e+02, 4.480865e+01],
                        [3.826667e+02, 3.857778e+02, 4.185955e+01],
                        [-2.313792e+03, -2.431276e+03, -2.392101e+02],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [9.333333e+01, 9.411111e+01, -1.904932e+03],
                        [9.566667e+01, 9.644444e+01, -2.344715e+03],
                        [1.429166e+01, 1.406212e+01, -3.417283e+02],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [4.253333e+02, 1.636843e+01, 0.0e+00],
                        [4.350000e+02, 1.217279e+01, 0.0e+00],
                        [-2.738517e+03, -4.792887e+01, 0.0e+00],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [1.063333e+02, -2.178747e+03, 0.0e+00],
                        [1.087500e+02, -2.670679e+03, 0.0e+00],
                        [6.947917e+00, -1.629574e+02, 0.0e+00],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                ]],
                &device,
            ),
            // Weight gradient: literal shape [2, 2, 3, 4]
            // (channels_out, channels_in, kernel_size_1, kernel_size_2).
            // The two output-channel slices are identical.
            weight: TestTensor::from_floats(
                [
                    [
                        [
                            [1.856041, 7.203409, 12.833395, 11.969448],
                            [24.236776, 40.125511, 41.396423, 27.642044],
                            [43.613083, 57.508926, 46.093338, 25.174383],
                        ],
                        [
                            [6.989914, 26.580338, 42.618557, 37.501404],
                            [75.623192, 116.925674, 113.288368, 72.567764],
                            [112.724869, 139.826447, 107.653435, 56.799385],
                        ],
                    ],
                    [
                        [
                            [1.856041, 7.203409, 12.833395, 11.969448],
                            [24.236776, 40.125511, 41.396423, 27.642044],
                            [43.613083, 57.508926, 46.093338, 25.174383],
                        ],
                        [
                            [6.989914, 26.580338, 42.618557, 37.501404],
                            [75.623192, 116.925674, 113.288368, 72.567764],
                            [112.724869, 139.826447, 107.653435, 56.799385],
                        ],
                    ],
                ],
                &device,
            ),
            // Mask gradient: literal shape [1, 12, 4, 3]. 12 = 3 * 4, presumably one scalar
            // per kernel tap.
            mask: TestTensor::from_floats(
                [[
                    [
                        [0.0e+00, 2.677941e+00, 5.857617e+00],
                        [4.015623e+01, 7.759999e+02, 8.492499e+02],
                        [6.637500e+01, 1.067750e+03, 1.141000e+03],
                        [9.865628e+01, 1.359500e+03, 1.432750e+03],
                    ],
                    [
                        [6.745831e+01, 7.688924e+01, 8.684974e+01],
                        [8.387916e+02, 9.161111e+02, 9.934306e+02],
                        [1.146750e+03, 1.224069e+03, 1.301389e+03],
                        [1.454708e+03, 1.532028e+03, 1.609347e+03],
                    ],
                    [
                        [1.547778e+02, 1.716607e+02, 1.460455e+02],
                        [9.861667e+02, 1.067556e+03, 8.756536e+02],
                        [1.310333e+03, 1.391722e+03, 1.110864e+03],
                        [1.634500e+03, 1.715889e+03, 1.339339e+03],
                    ],
                    [
                        [2.641250e+02, 1.993876e+02, 0.0e+00],
                        [1.144875e+03, 8.365740e+02, 0.0e+00],
                        [1.485250e+03, 1.056253e+03, 0.0e+00],
                        [1.825625e+03, 1.268859e+03, 0.0e+00],
                    ],
                    [
                        [3.800000e+02, 1.047861e+03, 1.137389e+03],
                        [5.276354e+02, 1.404444e+03, 1.493972e+03],
                        [6.826807e+02, 1.761028e+03, 1.850555e+03],
                        [5.038855e+02, 1.256341e+03, 1.304936e+03],
                    ],
                    [
                        [1.123500e+03, 1.217097e+03, 1.310694e+03],
                        [1.496292e+03, 1.589889e+03, 1.683486e+03],
                        [1.869083e+03, 1.962681e+03, 2.056278e+03],
                        [1.146700e+03, 1.190136e+03, 1.232930e+03],
                    ],
                    [
                        [1.300000e+03, 1.397667e+03, 6.512036e+02],
                        [1.689000e+03, 1.786667e+03, 8.072734e+02],
                        [2.078000e+03, 2.175667e+03, 9.552593e+02],
                        [1.060781e+03, 1.097745e+03, 4.656539e+02],
                    ],
                    [
                        [1.487833e+03, 5.672195e+02, 0.0e+00],
                        [1.893042e+03, 6.972655e+02, 0.0e+00],
                        [2.298250e+03, 8.188910e+02, 0.0e+00],
                        [9.472098e+02, 3.238781e+02, 0.0e+00],
                    ],
                    [
                        [1.216444e+03, 1.792806e+03, 1.898611e+03],
                        [1.536448e+03, 2.214222e+03, 2.320028e+03],
                        [5.177084e+02, 7.256571e+02, 7.493920e+02],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [1.897500e+03, 2.007375e+03, 2.117250e+03],
                        [2.335125e+03, 2.445000e+03, 2.554875e+03],
                        [5.591096e+02, 5.750975e+02, 5.903336e+02],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [2.119333e+03, 2.233278e+03, 2.654414e+02],
                        [2.573167e+03, 2.687111e+03, 2.907444e+02],
                        [3.856317e+02, 3.924502e+02, 3.737657e+01],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                    [
                        [2.352500e+03, 9.009851e+01, 0.0e+00],
                        [2.822542e+03, 7.854909e+01, 0.0e+00],
                        [1.785990e+02, 2.930897e+00, 0.0e+00],
                        [0.0e+00, 0.0e+00, 0.0e+00],
                    ],
                ]],
                &device,
            ),
            // Bias gradient: each entry is 12, which equals batch_size (1) * 4 * 3 output
            // positions. That is consistent with a ones-seeded backward pass (assumption).
            bias: TestTensor::from_floats([12., 12.], &device),
        };
        test.assert_grads(grads);
    }

    #[test]
    fn test_deform_conv2d_different_padding() {
        let test = Conv2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size_1: 3,
            kernel_size_2: 3,
            padding_1: 2,
            padding_2: 3,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            offset_groups: 1,
            height: 4,
            width: 4,
        };
        let device = Default::default();
        let grads = Grads {
            x: TestTensor::from_floats(
                [[
                    [
                        [60.633026, 60.906506, 61.179493, 61.451954],
                        [122.557770, 123.088188, 123.618599, 124.149033],
                        [126.801132, 127.331535, 127.861938, 128.392365],
                        [131.044434, 131.574875, 132.105286, 132.635712],
                    ],
                    [
                        [102.000595, 102.497604, 102.993835, 103.489281],
                        [198.932983, 199.830597, 200.728210, 201.625870],
                        [206.113968, 207.011627, 207.909256, 208.806870],
                        [213.294952, 214.192627, 215.090271, 215.987930],
                    ],
                ]],
                &device,
            ),
            // => Position 788: 10.421875 != 10.0546875
            //  diff (rel = +1.79e-2, abs = +3.67e-1), tol (rel = +1.00e-2, abs = +9.77e-4)
            offset: TestTensor::from_floats(
                [[
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [
                            0.0, 0.0, 0.895062, 14.760561, 17.604168, 20.698063, 22.200424, 0.0,
                        ],
                        [
                            0.0, 0.0, 0.687500, 9.500000, 10.0, 10.500000, 10.108797, 0.0,
                        ],
                        [
                            0.0, 0.0, 1.113426, 13.500000, 14.000000, 14.499999, 13.645835, 0.0,
                        ],
                        [
                            0.0, 0.0, 1.613426, 17.500000, 18.000000, 18.500000, 17.108795, 0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            -12.395836,
                            -122.399445,
                            -130.752319,
                            -139.355469,
                            -131.526810,
                            0.0,
                        ],
                    ],
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [
                            0.0, 0.0, 0.154321, 0.017506, 0.020833, 0.024450, -0.387539, 0.0,
                        ],
                        [
                            0.0, 0.0, 24.187502, 2.375000, 2.500000, 2.625000, -37.863422, 0.0,
                        ],
                        [
                            0.0, 0.0, 48.057869, 3.375000, 3.500000, 3.625000, -66.770836, 0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            80.02312,
                            4.375000,
                            4.500000,
                            4.625000,
                            -103.752319,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            113.215271,
                            5.107495,
                            5.219907,
                            5.332031,
                            -139.725891,
                            0.0,
                        ],
                    ],
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [
                            0.0, 14.206017, 83.017586, 92.379395, 102.010040, 90.356323, 0.0, 0.0,
                        ],
                        [
                            0.0, 6.504737, 35.444443, 35.981483, 36.518517, 29.978970, 0.0, 0.0,
                        ],
                        [
                            0.0, 7.668316, 39.740742, 40.277779, 40.814816, 33.071907, 0.0, 0.0,
                        ],
                        [
                            0.0, 8.911458, 44.037037, 44.574074, 45.111111, 36.085281, 0.0, 0.0,
                        ],
                        [
                            0.0,
                            -57.523048,
                            -274.267914,
                            -289.547089,
                            -305.095093,
                            -248.578552,
                            0.0,
                            0.0,
                        ],
                    ],
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [
                            0.0, 9.749230, 0.955354, 0.980994, 1.006945, -13.930464, 0.0, 0.0,
                        ],
                        [
                            0.0,
                            96.046921,
                            8.861111,
                            8.995371,
                            9.129629,
                            -129.920715,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            147.434769,
                            9.935185,
                            10.069445,
                            10.203704,
                            -186.718735,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            207.494781,
                            11.009259,
                            11.143518,
                            11.277778,
                            -252.188889,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            226.050003,
                            10.153355,
                            10.252030,
                            10.350393,
                            -266.255280,
                            0.0,
                            0.0,
                        ],
                    ],
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [
                            44.224964, 159.898483, 176.651901, 193.692688, 146.270813, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            19.050755, 64.870377, 65.444443, 66.018517, 46.553150, 0.0, 0.0, 0.0,
                        ],
                        [
                            21.049385, 69.462967, 70.037033, 70.611115, 49.104595, 0.0, 0.0, 0.0,
                        ],
                        [
                            23.133059, 74.055557, 74.629631, 75.203705, 51.570988, 0.0, 0.0, 0.0,
                        ],
                        [
                            -141.200272,
                            -445.302155,
                            -468.381012,
                            -491.747223,
                            -341.553131,
                            0.0,
                            0.0,
                            0.0,
                        ],
                    ],
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [
                            35.665298, 3.505739, 3.556735, 3.608062, -48.756947, 0.0, 0.0, 0.0,
                        ],
                        [
                            181.404663,
                            16.217594,
                            16.361111,
                            16.504629,
                            -238.136124,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            263.888885,
                            17.365742,
                            17.509258,
                            17.652779,
                            -326.403656,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            355.643341,
                            18.513889,
                            18.657408,
                            18.800926,
                            -423.941345,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            318.709198,
                            14.359658,
                            14.441552,
                            14.523109,
                            -369.819580,
                            0.0,
                            0.0,
                            0.0,
                        ],
                    ],
                    [
                        [
                            0.0, 0.0, 88.846703, 237.478439, 261.731201, 286.289917, 182.508713,
                            0.0,
                        ],
                        [
                            0.0, 0.0, 37.688015, 94.722221, 95.333328, 95.944450, 57.441605, 0.0,
                        ],
                        [
                            0.0, 0.0, 40.562500, 99.611107, 100.222229, 100.833336, 59.410744, 0.0,
                        ],
                        [
                            0.0, 0.0, 43.527519, 104.500000, 105.111107, 105.722221, 61.289349, 0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            -258.324371,
                            -618.353943,
                            -649.340271,
                            -680.632507,
                            -397.101013,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0,
                            0.0,
                            76.229431,
                            7.564093,
                            7.641718,
                            7.719699,
                            -102.792252,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            272.015167,
                            23.680555,
                            23.833332,
                            23.986113,
                            -351.944214,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            386.062500,
                            24.902777,
                            25.055557,
                            25.208334,
                            -472.147888,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            509.978149,
                            26.125000,
                            26.277777,
                            26.430555,
                            -602.219971,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            378.410248,
                            17.123661,
                            17.187500,
                            17.250984,
                            -436.000732,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0, 157.623291, 331.938538, 365.283356, 398.952606, 205.988480, 0.0,
                            0.0,
                        ],
                        [
                            0.0, 66.495949, 130.925934, 131.574066, 132.222229, 64.435974, 0.0, 0.0,
                        ],
                        [
                            0.0, 70.396835, 136.111115, 136.759262, 137.407410, 65.672256, 0.0, 0.0,
                        ],
                        [
                            0.0, 74.393753, 141.296295, 141.944458, 142.592606, 66.812523, 0.0, 0.0,
                        ],
                        [
                            0.0,
                            -432.798035,
                            -827.492065,
                            -867.978455,
                            -908.789368,
                            -425.074158,
                            0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0,
                            140.150024,
                            14.043960,
                            14.152921,
                            14.262260,
                            -187.656906,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            386.813873,
                            32.731483,
                            32.893517,
                            33.055557,
                            -494.779602,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            538.926697,
                            34.027779,
                            34.189816,
                            34.351852,
                            -653.421875,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            701.505859,
                            35.324074,
                            35.486115,
                            35.648151,
                            -822.530640,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            416.044586,
                            18.903570,
                            18.944647,
                            18.985338,
                            -476.728790,
                            0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            249.876541, 435.868500, 479.178772, 522.832031, 207.919815, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            105.417015, 170.611115, 171.296295, 171.981476, 64.750000, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            110.441696, 176.092590, 176.777771, 177.462952, 65.156044, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            115.567902, 181.574066, 182.259247, 182.944458, 65.460571, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            -662.743530,
                            -1056.641846,
                            -1107.501953,
                            -1158.704712,
                            -409.510162,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            227.160507,
                            22.982454,
                            23.125793,
                            23.269531,
                            -303.112030,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            518.495178,
                            42.652779,
                            42.824074,
                            42.995369,
                            -657.157410,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            712.252380,
                            44.023148,
                            44.194443,
                            44.365738,
                            -857.817200,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            917.074036,
                            45.393517,
                            45.564812,
                            45.736115,
                            -1069.541626,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            416.581482,
                            18.997831,
                            19.013102,
                            19.027966,
                            -475.031525,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0, 0.0, 151.750259, 210.166672, 210.888885, 211.611099, 57.506927,
                            0.0,
                        ],
                        [
                            0.0, 0.0, 157.929276, 215.944443, 216.666672, 217.388901, 57.052204,
                            0.0,
                        ],
                        [
                            0.0, 0.0, 164.215271, 221.722229, 222.444458, 223.166672, 56.490482,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            -931.783752,
                            -1285.353760,
                            -1346.555908,
                            -1408.119385,
                            -346.739044,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0,
                            0.0,
                            655.669983,
                            52.541668,
                            52.722221,
                            52.902775,
                            -824.946777,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            890.972473,
                            53.986111,
                            54.166668,
                            54.347225,
                            -1067.525024,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            1137.937500,
                            55.430557,
                            55.611115,
                            55.791668,
                            -1321.765625,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            375.580566,
                            17.180984,
                            17.169498,
                            17.157579,
                            -425.993713,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0, 213.521454, 256.629639, 257.388885, 258.148132, 41.652927, 0.0,
                            0.0,
                        ],
                        [
                            0.0, 221.015625, 262.703705, 263.462982, 264.222229, 40.176598, 0.0,
                            0.0,
                        ],
                        [
                            0.0, 228.622284, 268.777802, 269.537048, 270.296295, 38.587788, 0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            -1285.466797,
                            -1554.254517,
                            -1627.530640,
                            -1701.186646,
                            -228.291397,
                            0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0,
                            823.380554,
                            64.157410,
                            64.347221,
                            64.537033,
                            -1028.532715,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            1107.296509,
                            65.675926,
                            65.865746,
                            66.055557,
                            -1320.097534,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            1403.473022,
                            67.194450,
                            67.384262,
                            67.574074,
                            -1623.922974,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            288.151398,
                            13.201796,
                            13.158524,
                            13.114797,
                            -323.577820,
                            0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            288.790131, 306.574066, 307.370361, 308.166656, 15.734239, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            297.696838, 312.944427, 313.740723, 314.537048, 13.138914, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            306.721527, 319.314819, 320.111115, 320.907410, 10.425544, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            -1711.543213,
                            -1844.013062,
                            -1930.236572,
                            -2016.858643,
                            -46.846100,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            1011.358093,
                            76.643517,
                            76.842590,
                            77.041664,
                            -1255.045654,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            1347.466431,
                            78.236107,
                            78.435181,
                            78.634262,
                            -1599.175903,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            1696.433350,
                            79.828705,
                            80.027779,
                            80.226852,
                            -1956.164917,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            146.703568,
                            6.690874,
                            6.612756,
                            6.534196,
                            -159.277222,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                ]],
                &device,
            ),
            weight: TestTensor::from_floats(
                [
                    [
                        [
                            [10.341997, 22.988085, 35.634174],
                            [46.920216, 59.566299, 72.212387],
                            [80.881615, 92.591522, 104.158524],
                        ],
                        [
                            [29.213360, 68.837769, 108.462166],
                            [143.825104, 183.449509, 223.073944],
                            [228.029373, 256.751740, 283.807098],
                        ],
                    ],
                    [
                        [
                            [10.341997, 22.988085, 35.634174],
                            [46.920216, 59.566299, 72.212387],
                            [80.881615, 92.591522, 104.158524],
                        ],
                        [
                            [29.213360, 68.837769, 108.462166],
                            [143.825104, 183.449509, 223.073944],
                            [228.029373, 256.751740, 283.807098],
                        ],
                    ],
                ],
                &device,
            ),
            mask: TestTensor::from_floats(
                [[
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [
                            0.0, 0.0, 0.447531, 7.380288, 8.802088, 10.349031, 11.100212, 0.0,
                        ],
                        [
                            0.0, 0.0, 44.343754, 584.937439, 639.250000, 693.562439, 683.262756,
                            0.0,
                        ],
                        [
                            0.0, 0.0, 68.390068, 803.437561, 857.750000, 912.062500, 874.698059,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            96.473381,
                            1021.937500,
                            1076.250000,
                            1130.562500,
                            1062.095947,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            121.302101,
                            1168.487915,
                            1218.373779,
                            1268.134888,
                            1169.444702,
                            0.0,
                        ],
                    ],
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [
                            0.0, 13.084491, 75.860909, 83.767761, 91.809029, 80.728188, 0.0, 0.0,
                        ],
                        [
                            0.0, 118.950417, 649.486084, 707.821777, 766.157410, 658.076599, 0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            170.660782,
                            884.171265,
                            942.506958,
                            1000.842651,
                            837.809326,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            226.707260,
                            1118.856445,
                            1177.192261,
                            1235.527710,
                            1013.205933,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            234.939651,
                            1106.213867,
                            1153.415649,
                            1200.482666,
                            966.248901,
                            0.0,
                            0.0,
                        ],
                    ],
                    [
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [
                            42.524002, 153.045700, 168.319275, 183.736511, 138.144653, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            207.319611, 718.432800, 780.791626, 843.150391, 619.975037, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            290.277802,
                            969.303223,
                            1031.661987,
                            1094.020752,
                            784.421631,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            377.871063,
                            1220.173584,
                            1282.532471,
                            1344.891235,
                            944.233032,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            328.083038,
                            1025.494995,
                            1069.130615,
                            1112.622192,
                            766.054932,
                            0.0,
                            0.0,
                            0.0,
                        ],
                    ],
                    [
                        [
                            0.0, 0.0, 88.238174, 235.055206, 258.194336, 281.486389, 178.858536,
                            0.0,
                        ],
                        [
                            0.0, 0.0, 305.575500, 789.868042, 856.250061, 922.631897, 572.466064,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            421.809021,
                            1056.923584,
                            1123.305542,
                            1189.687500,
                            719.598816,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            542.976746,
                            1323.979248,
                            1390.361206,
                            1456.743042,
                            861.797302,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            393.291565,
                            934.439697,
                            974.010376,
                            1013.428101,
                            586.924011,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0, 157.214920, 330.227448, 362.473419, 394.881653, 203.374420, 0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            424.340576,
                            867.495361,
                            937.900452,
                            1008.305542,
                            505.640503,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            578.894897,
                            1150.736084,
                            1221.141235,
                            1291.546265,
                            630.414001,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            738.682495,
                            1433.976929,
                            1504.381958,
                            1574.787109,
                            749.954346,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0, 429.912781, 816.507507, 850.771973, 884.873779, 411.152588, 0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            249.876541, 434.964233, 477.198730, 519.604675, 206.215576, 0.0, 0.0,
                            0.0,
                        ],
                        [
                            560.309326,
                            949.520813,
                            1023.949097,
                            1098.377319,
                            422.458344,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            756.768127,
                            1248.946777,
                            1323.375000,
                            1397.803223,
                            521.289001,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            958.759216,
                            1548.372803,
                            1622.800903,
                            1697.229248,
                            614.587402,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            428.833923, 679.269775, 707.346252, 735.250916, 258.169373, 0.0, 0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0,
                            0.0,
                            707.671387,
                            1033.687378,
                            1112.138916,
                            1190.590210,
                            328.295044,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            947.779419,
                            1349.298584,
                            1427.750000,
                            1506.201416,
                            399.438080,
                            0.0,
                        ],
                        [
                            0.0,
                            0.0,
                            1193.718872,
                            1664.909668,
                            1743.361084,
                            1821.812500,
                            464.749847,
                            0.0,
                        ],
                        [
                            0.0, 0.0, 388.737854, 532.503540, 553.962891, 575.241089, 140.658310,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            0.0,
                            880.797302,
                            1124.393555,
                            1206.868042,
                            1289.342651,
                            209.627625,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            1169.882812,
                            1456.189819,
                            1538.664429,
                            1621.138916,
                            247.754730,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0,
                            1465.098755,
                            1787.986084,
                            1870.460571,
                            1952.935181,
                            279.751526,
                            0.0,
                            0.0,
                        ],
                        [
                            0.0, 297.330719, 356.362152, 369.893524, 383.234344, 50.974621, 0.0,
                            0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                    [
                        [
                            1074.567993,
                            1219.497681,
                            1305.995361,
                            1392.493042,
                            71.162437,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            1416.214722,
                            1567.479126,
                            1653.976929,
                            1740.474609,
                            72.689949,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            1764.290771,
                            1915.460571,
                            2001.958496,
                            2088.456055,
                            67.787628,
                            0.0,
                            0.0,
                            0.0,
                        ],
                        [
                            151.018372, 160.055023, 164.776138, 169.298447, 3.865937, 0.0, 0.0, 0.0,
                        ],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                    ],
                ]],
                &device,
            ),
            bias: TestTensor::from_floats([48., 48.], &device),
        };
        test.assert_grads(grads);
    }

    /// Hyper-parameters describing one `deform_conv2d` gradient test case.
    ///
    /// Fields suffixed `_1` apply to the height axis and fields suffixed `_2` to the width axis
    /// (see how `assert_grads` pairs them with `height` and `width`).
    struct Conv2dTestCase {
        /// Number of samples in the batch.
        batch_size: usize,
        /// Channels of the input tensor `x`.
        channels_in: usize,
        /// Channels produced by the convolution (also the length of `bias`).
        channels_out: usize,
        /// Convolution groups; the weight holds `channels_in / groups` input channels.
        groups: usize,
        /// Offset groups; the offset tensor has `2 * offset_groups * k1 * k2` channels.
        offset_groups: usize,
        /// Input height.
        height: usize,
        /// Input width.
        width: usize,
        /// Kernel extent along the height axis.
        kernel_size_1: usize,
        /// Kernel extent along the width axis.
        kernel_size_2: usize,
        /// Zero padding along the height axis.
        padding_1: usize,
        /// Zero padding along the width axis.
        padding_2: usize,
        /// Stride along the height axis.
        stride_1: usize,
        /// Stride along the width axis.
        stride_2: usize,
        /// Dilation along the height axis.
        dilation_1: usize,
        /// Dilation along the width axis.
        dilation_2: usize,
    }

    /// Expected gradients for every differentiable input of `deform_conv2d`.
    ///
    /// Fields are listed in the order `assert_grads` checks them.
    struct Grads {
        /// Expected gradient of the bias vector.
        bias: TestTensor<1>,
        /// Expected gradient of the input feature map `x`.
        x: TestTensor<4>,
        /// Expected gradient of the sampling offsets.
        offset: TestTensor<4>,
        /// Expected gradient of the mask.
        mask: TestTensor<4>,
        /// Expected gradient of the convolution weight.
        weight: TestTensor<4>,
    }

    impl Conv2dTestCase {
        /// Spatial output size of a convolution along one axis:
        /// `(input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1`.
        fn output_size(
            input: usize,
            padding: usize,
            dilation: usize,
            kernel: usize,
            stride: usize,
        ) -> usize {
            (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
        }

        /// Runs `deform_conv2d` forward and backward on deterministic inputs and checks every
        /// gradient against `expected_grads`.
        ///
        /// Inputs are built from `arange` so the expected values are reproducible:
        /// - `x` and `weight` hold `arange(n)` reshaped to their shapes;
        /// - `bias` holds `arange(channels_out)`;
        /// - `offset` and `mask` hold `arange(n) / n`, i.e. values in `[0, 1)`.
        ///
        /// # Panics
        ///
        /// Panics if an input has no gradient after `backward`, or if any gradient differs
        /// from its expected value by more than the relative tolerance.
        fn assert_grads(self, expected_grads: Grads) {
            let out_height = Self::output_size(
                self.height,
                self.padding_1,
                self.dilation_1,
                self.kernel_size_1,
                self.stride_1,
            );
            let out_width = Self::output_size(
                self.width,
                self.padding_2,
                self.dilation_2,
                self.kernel_size_2,
                self.stride_2,
            );

            let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
            // Two offset channels per kernel tap and offset group.
            let shape_offset = Shape::new([
                self.batch_size,
                2 * self.offset_groups * self.kernel_size_1 * self.kernel_size_2,
                out_height,
                out_width,
            ]);
            let shape_weight = Shape::new([
                self.channels_out,
                self.channels_in / self.groups,
                self.kernel_size_1,
                self.kernel_size_2,
            ]);
            // One mask channel per kernel tap and offset group.
            let shape_mask = Shape::new([
                self.batch_size,
                self.offset_groups * self.kernel_size_1 * self.kernel_size_2,
                out_height,
                out_width,
            ]);
            let device = Default::default();
            let weight = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_weight)
                    .into_data(),
                &device,
            )
            .require_grad();
            let bias = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
                &device,
            )
            .require_grad();
            let x = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_x)
                    .into_data(),
                &device,
            )
            .require_grad();
            // Scale into [0, 1) so offsets land on fractional positions.
            let offset = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_offset.clone())
                    .into_data(),
                &device,
            )
            .div_scalar(shape_offset.num_elements() as f32)
            .require_grad();

            let mask = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_mask.clone())
                    .into_data(),
                &device,
            )
            .div_scalar(shape_mask.num_elements() as f32)
            .require_grad();

            let output = deform_conv2d(
                x.clone(),
                offset.clone(),
                weight.clone(),
                Some(mask.clone()),
                Some(bias.clone()),
                DeformConvOptions::new(
                    [self.stride_1, self.stride_2],
                    [self.padding_1, self.padding_2],
                    [self.dilation_1, self.dilation_2],
                    self.groups,
                    self.offset_groups,
                ),
            );
            let grads = output.backward();

            // Every input above was marked `require_grad`, so each must have a gradient.
            let x_grad_actual = x.grad(&grads).expect("x should have a gradient");
            let offset_grad_actual = offset.grad(&grads).expect("offset should have a gradient");
            let weight_grad_actual = weight.grad(&grads).expect("weight should have a gradient");
            let mask_grad_actual = mask.grad(&grads).expect("mask should have a gradient");
            let bias_grad_actual = bias.grad(&grads).expect("bias should have a gradient");

            // Relative tolerance is set to 5% (0.05), which is much higher than typical numerical
            // test tolerances. This is due to the complexity of the deformable convolution
            // operation. Unlike regular conv2d, which samples from fixed integer grid positions,
            // deformable conv2d samples input values at fractional offset locations (learned
            // offsets). These non-integer positions require bilinear interpolation to estimate
            // the input value. Gradients computed through all these floating-point operations
            // can compound numerical differences.
            let tolerance = Tolerance::relative(0.05);

            // Checked in a fixed order; the label is printed first so a failing assertion
            // is easy to attribute.
            let checks = [
                ("bias", expected_grads.bias.to_data(), bias_grad_actual.to_data()),
                ("input", expected_grads.x.to_data(), x_grad_actual.to_data()),
                ("offset", expected_grads.offset.to_data(), offset_grad_actual.to_data()),
                ("mask", expected_grads.mask.to_data(), mask_grad_actual.to_data()),
                ("weight", expected_grads.weight.to_data(), weight_grad_actual.to_data()),
            ];
            for (name, expected, actual) in checks {
                println!("Testing {name}");
                expected.assert_approx_eq::<FloatType>(&actual, tolerance);
            }
        }
    }
}
